import os
import torch
import subprocess

from pathlib import Path
from ..utils.base_model import BaseModel
from .. import logger

from .networks.dkm.models.model_zoo.DKMv3 import DKMv3

# Directory of the bundled DKM network package; GIM uses it as the default
# checkpoint directory.
weight_path = Path(__file__).parent.joinpath("networks", "dkm")
# Torch device selected at import time: CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class GIM(BaseModel):
    """Dense image matcher: a DKMv3 network loaded with GIM checkpoint weights.

    Config keys:
        model_name: checkpoint file name looked up inside ``checkpoint_dir``.
        match_threshold: kept for config compatibility; this wrapper does not
            apply it.
        checkpoint_dir: directory containing ``model_name``.
        max_keypoints: number of matches requested from ``self.net.sample``.
    """

    default_conf = {
        "model_name": "gim_dkm_100h.ckpt",
        "match_threshold": 0.2,
        "checkpoint_dir": weight_path,
        # Forwarded to ``self.net.sample``; 10000 assumed to mirror DKM's own
        # sampling default — TODO confirm against DKMv3.sample.
        "max_keypoints": 10000,
    }
    required_inputs = [
        "image0",
        "image1",
    ]
    # Optional download URLs keyed by ``model_name``. No URL is registered for
    # the GIM checkpoint, so it must already be present in ``checkpoint_dir``.
    dkm_models = {}
    # (height, width) DKMv3 is constructed with; ``_forward`` pads inputs to
    # the matching width / height aspect ratio.
    _model_hw = (672, 896)

    def _init(self, conf):
        """Locate (or download) the checkpoint and build ``self.net``.

        Raises:
            FileNotFoundError: the checkpoint is missing and no download URL is
                registered in ``dkm_models`` for ``conf["model_name"]``.
            subprocess.CalledProcessError: the ``wget`` download failed.
        """
        model_path = Path(conf.get("checkpoint_dir", weight_path)) / conf["model_name"]
        if not model_path.exists():
            self._download(conf["model_name"], model_path)

        height, width = self._model_hw
        model = DKMv3(None, height, width, upsample_preds=True)

        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(str(model_path), map_location="cpu")
        # The weights may be wrapped under a "state_dict" key (training
        # framework checkpoint) or stored directly.
        state_dict = checkpoint.get("state_dict", checkpoint)
        model.load_state_dict(self._clean_state_dict(state_dict))

        self.net = model

    def _download(self, model_name, model_path):
        """Fetch ``model_name`` into ``model_path`` with wget via ``dkm_models``."""
        link = self.dkm_models.get(model_name)
        if link is None:
            raise FileNotFoundError(
                f"Checkpoint {model_path} not found and no download URL is "
                f"registered for {model_name!r}; place the file there manually."
            )
        model_path.parent.mkdir(parents=True, exist_ok=True)
        cmd = ["wget", link, "-O", str(model_path)]
        logger.info(f"Downloading the DKMv3 model with `{cmd}`.")
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError:
            # wget -O creates the output file up front; remove the partial file
            # so a later run does not try to load a truncated checkpoint.
            if model_path.exists():
                model_path.unlink()
            raise

    @staticmethod
    def _clean_state_dict(state_dict):
        """Return a copy with a leading ``model.`` stripped from every key and
        all ``encoder.net.fc`` entries removed."""
        prefix = "model."
        cleaned = {}
        for key, value in state_dict.items():
            if key.startswith(prefix):
                key = key[len(prefix):]
            # Presumably the encoder's fc head is absent from DKMv3, so a
            # strict load_state_dict would reject these keys.
            if "encoder.net.fc" in key:
                continue
            cleaned[key] = value
        return cleaned

    @staticmethod
    def _centered_padding(height, width, target_height, target_width):
        """Return ``(left, right, top, bottom)`` padding that centres a
        ``height x width`` image on a ``target_height x target_width`` canvas."""
        pad_h = target_height - height
        pad_w = target_width - width
        top = pad_h // 2
        left = pad_w // 2
        return left, pad_w - left, top, pad_h - top

    @staticmethod
    def _to_pixels(coords, height, width, left, top):
        """Map (x, y) coords to pixels of a ``height x width`` padded image,
        then shift by the padding to express them in the unpadded frame.

        Assumes ``coords`` are normalised to [-1, 1], as implied by the
        ``(c + 1) / 2`` mapping.
        """
        x = width * (coords[:, 0] + 1) / 2 - left
        y = height * (coords[:, 1] + 1) / 2 - top
        return torch.stack((x, y), dim=-1)

    @staticmethod
    def _in_bounds(kpts, height, width):
        """Boolean mask of points with 0 < x <= width - 1 and 0 < y <= height - 1."""
        # NOTE(review): the lower bound is strict while the upper bound is
        # width - 1 / height - 1; kept unchanged to preserve match output.
        x, y = kpts[:, 0], kpts[:, 1]
        return (x > 0) & (y > 0) & (x <= width - 1) & (y <= height - 1)

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Both images are zero-padded (centred) onto a shared canvas with the
        model's aspect ratio, densely matched, and sampled. The sampled
        keypoints are mapped back to each image's original pixel frame.
        Samples with zero certainty, or landing outside either original image,
        are dropped.

        Args:
            data: dict with ``image0`` / ``image1`` tensors whose last two
                dims are (H, W) — presumably (B, C, H, W) with B == 1; TODO
                confirm.

        Returns:
            dict with ``keypoints0`` and ``keypoints1``: (M, 2) tensors of
            (x, y) pixel coordinates; row i of both forms one match.
        """
        image0, image1 = data["image0"], data["image1"]
        height0, width0 = image0.shape[-2:]
        height1, width1 = image1.shape[-2:]

        # Canvas at least as large as both images and, per image, extended
        # along one axis far enough to reach the model aspect ratio.
        model_h, model_w = self._model_hw
        aspect_ratio = model_w / model_h
        target_w = max(
            width0, width1, int(height0 * aspect_ratio), int(height1 * aspect_ratio)
        )
        target_h = max(
            height0, height1, int(width0 / aspect_ratio), int(width1 / aspect_ratio)
        )

        left0, right0, top0, bottom0 = self._centered_padding(
            height0, width0, target_h, target_w
        )
        left1, right1, top1, bottom1 = self._centered_padding(
            height1, width1, target_h, target_w
        )
        image0 = torch.nn.functional.pad(image0, (left0, right0, top0, bottom0))
        image1 = torch.nn.functional.pad(image1, (left1, right1, top1, bottom1))

        logger.debug("GIM: dense matching")
        dense_matches, dense_certainty = self.net.match(image0, image1)
        logger.debug("GIM: sampling matches")
        sparse_matches, mconf = self.net.sample(
            dense_matches, dense_certainty, self.conf["max_keypoints"]
        )

        padded_h0, padded_w0 = image0.shape[-2:]
        padded_h1, padded_w1 = image1.shape[-2:]
        kpts0 = self._to_pixels(sparse_matches[:, :2], padded_h0, padded_w0, left0, top0)
        kpts1 = self._to_pixels(sparse_matches[:, 2:], padded_h1, padded_w1, left1, top1)

        # Certainty filter and bounds filter are combined into one mask over
        # all N samples, so the two always have matching lengths.
        keep = (
            (mconf != 0)
            & self._in_bounds(kpts0, height0, width0)
            & self._in_bounds(kpts1, height1, width1)
        )
        return {"keypoints0": kpts0[keep], "keypoints1": kpts1[keep]}
